# -*- coding: utf-8 -*-
"""NC_SELI_Imbalance.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1xg4LDSrI0tjrmHOZhPiTrELLzf9Sr8hE
"""

import os
import sys
import pickle
import shutil

import torch

import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.models as models

from tqdm import tqdm
from collections import OrderedDict
from scipy.sparse.linalg import svds
from torchvision import datasets, transforms

from generate_cifar import IMBALANCECIFAR10
from generate_mnist import IMBALANCEMNIST


#-------  analysis ---------------------------------------
class graphs:
  """Accumulator for the metric histories recorded at each analysis checkpoint.

  Every attribute starts as its own empty list. The `analysis` function in
  this file appends one value per call to most of them; `cur_epochs` is
  assigned by the training driver. The "decomposition" series are created
  here but are not filled by any code visible in this file.
  """

  def __init__(self):
    # Attribute names, grouped by the quantity they track.
    bookkeeping   = ('cur_epochs', 'accuracy', 'loss', 'reg_loss')
    held_out      = ('test_loss', 'test_acc')
    nc1           = ('Sw_invSb',)
    nc2           = ('norm_M_CoV', 'norm_W_CoV', 'cos_M', 'cos_W')
    nc3           = ('W_M_dist',)
    nc4           = ('NCC_mismatch',)
    decomposition = ('MSE_wd_features', 'LNC1', 'LNC23', 'Lperp')

    groups = (bookkeeping, held_out, nc1, nc2, nc3, nc4, decomposition)
    for group in groups:
        for series_name in group:
            # a fresh list per attribute so the histories never alias
            setattr(self, series_name, [])


#------- train fcn ---------------------------------------
def train(model, criterion, device, num_classes, train_loader, optimizer, epoch, n_c_train_target):
    """Run one optimisation epoch and return the per-class training accuracy.

    Args:
        model: network to optimise; put into training mode here.
        criterion: loss called as ``criterion(model(data), target)``.
        device: device the batches are moved to.
        num_classes: number of classes; class indices are ``0..num_classes-1``.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimiser stepped once per processed batch.
        epoch: epoch number, used only in the progress-bar text.
        n_c_train_target: mapping ``class index -> sample count`` used as the
            denominator of the per-class accuracy.

    Returns:
        dict mapping each class index to
        ``(correct predictions for that class) / n_c_train_target[class]``.

    Notes:
        Relies on the module-level globals ``batch_size`` and ``debug``.
        Batches whose size differs from ``batch_size`` are skipped entirely, so
        their samples never enter the numerators while the denominators stay
        fixed; the reported accuracies can therefore be slightly pessimistic.
    """
    model.train()

    per_class_acc = {c: 0 for c in range(num_classes)}
    num_batches = len(train_loader)

    pbar = tqdm(total=num_batches, position=0, leave=True)
    try:
        for batch_idx, (data, target) in enumerate(train_loader, start=1):
            # skip incomplete batches (typically the last one of the epoch)
            if data.shape[0] != batch_size:
                continue

            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()

            # correctness mask computed once and reused for all statistics
            hits = torch.argmax(out, dim=1) == target
            accuracy = hits.float().mean().item()

            pbar.update(1)
            pbar.set_description(
                'Train\t\tEpoch: {} [{}/{} ({:.0f}%)] \t'
                'Batch Loss: {:.6f} \t'
                'Batch Accuracy: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    num_batches,
                    100. * batch_idx / num_batches,
                    loss.item(),
                    accuracy))

            for c in range(num_classes):
                per_class_acc[c] += (hits & (target == c)).sum().item()

            # debug runs stop early, once batch_idx exceeds 20
            if debug and batch_idx > 20:
                break
    finally:
        # release the progress bar even if a batch raised
        pbar.close()

    for c in range(num_classes):
        per_class_acc[c] /= n_c_train_target[c]
    print("Training per_class_acc: " + str(per_class_acc))

    return per_class_acc



#------- analysis fcn ---------------------------------------
def _summed_loss(criterion_summed, output, target, num_classes):
    """Return one batch's summed loss as a Python number.

    Cross-entropy receives the integer targets directly; MSE receives one-hot
    float targets. Any other criterion contributes 0, matching the previous
    behaviour of silently ignoring unsupported criteria.
    """
    if isinstance(criterion_summed, nn.CrossEntropyLoss):
        return criterion_summed(output, target).item()
    if isinstance(criterion_summed, nn.MSELoss):
        one_hot = F.one_hot(target, num_classes=num_classes).float()
        return criterion_summed(output, one_hot).item()
    return 0


def _collapse_analysis(graph, model, criterion_summed, device, num_classes, loader, features, epoch, classifier):
    """Compute the neural-collapse statistics on `loader` and append them to `graph`.

    Makes two passes over `loader`: the first accumulates class means (and the
    loss), the second the within-class covariance, the network accuracy and
    the agreement between the network and the nearest-class-centre rule.
    Returns the list of per-class mean feature vectors.

    Must be called with autograd disabled (the caller wraps it in
    ``torch.no_grad()``). Reads the module-level globals ``debug`` and
    ``weight_decay``.
    """
    N = [0] * num_classes
    mean = [0] * num_classes
    Sw = 0

    loss = 0
    net_correct = 0
    NCC_match_net = 0

    M = None              # (features, classes) matrix of class means
    class_centers = None  # (classes, features) view of M
    num_batches = len(loader)

    for computation in ['Mean', 'Cov']:
        pbar = tqdm(total=num_batches, position=0, leave=True)
        for batch_idx, (data, target) in enumerate(loader, start=1):

            data, target = data.to(device), target.to(device)

            output = model(data)
            # `features.value` is refreshed by the forward pass just above
            h = features.value.data.view(data.shape[0], -1)  # B CHW

            # during calculation of class means, calculate loss
            if computation == 'Mean':
                loss += _summed_loss(criterion_summed, output, target, num_classes)

            for c in range(num_classes):
                # features belonging to class c
                idxs = (target == c).nonzero(as_tuple=True)[0]

                if len(idxs) == 0:  # no class-c sample in this batch
                    continue

                h_c = h[idxs, :]  # B_c CHW

                if computation == 'Mean':
                    # update class means
                    mean[c] += torch.sum(h_c, dim=0)  # CHW
                    N[c] += h_c.shape[0]

                elif computation == 'Cov':
                    # update within-class cov: sum of outer products z_i z_i^T,
                    # written as one matrix product instead of a B x D x D tensor
                    z = h_c - mean[c].unsqueeze(0)  # B_c CHW
                    Sw += z.T @ z

                    # 1) network's accuracy
                    net_pred = torch.argmax(output[idxs, :], dim=1)
                    net_correct += (net_pred == target[idxs]).sum().item()

                    # 2) agreement between prediction and nearest class center
                    NCC_scores = torch.norm(
                        h_c.unsqueeze(1) - class_centers.unsqueeze(0), dim=2)  # B_c C
                    NCC_pred = torch.argmin(NCC_scores, dim=1)
                    NCC_match_net += (NCC_pred == net_pred).sum().item()

            pbar.update(1)
            pbar.set_description(
                'Analysis {}\t'
                'Epoch: {} [{}/{} ({:.0f}%)]'.format(
                    computation,
                    epoch,
                    batch_idx,
                    num_batches,
                    100. * batch_idx / num_batches))

            if debug and batch_idx > 20:
                break
        pbar.close()

        if computation == 'Mean':
            for c in range(num_classes):
                if N[c] == 0:
                    raise ValueError(
                        "class {} has no samples in the analysis loader; "
                        "its mean is undefined".format(c))
                mean[c] /= N[c]
            # build the mean matrix once, after every class mean is final
            M = torch.stack(mean).T
            class_centers = M.T
            loss /= sum(N)
        elif computation == 'Cov':
            Sw /= sum(N)

    graph.loss.append(loss)
    graph.accuracy.append(net_correct / sum(N))
    graph.NCC_mismatch.append(1 - NCC_match_net / sum(N))

    # loss with weight decay
    reg_loss = loss
    for param in model.parameters():
        reg_loss += 0.5 * weight_decay * torch.sum(param**2).item()
    graph.reg_loss.append(reg_loss)

    # global mean
    muG = torch.mean(M, dim=1, keepdim=True)  # CHW 1

    # between-class covariance
    M_ = M - muG
    Sb = torch.matmul(M_, M_.T) / num_classes

    # avg norm
    W = classifier.weight
    M_norms = torch.norm(M_, dim=0)
    W_norms = torch.norm(W.T, dim=0)

    graph.norm_M_CoV.append((torch.std(M_norms) / torch.mean(M_norms)).item())
    graph.norm_W_CoV.append((torch.std(W_norms) / torch.mean(W_norms)).item())

    # tr{Sw Sb^-1}, with Sb inverted on its top (num_classes - 1) singular directions
    Sw = Sw.cpu().numpy()
    Sb = Sb.cpu().numpy()
    eigvec, eigval, _ = svds(Sb, k=num_classes - 1)
    inv_Sb = eigvec @ np.diag(eigval**(-1)) @ eigvec.T
    graph.Sw_invSb.append(np.trace(Sw @ inv_Sb))

    # ||W^T - M_||
    normalized_M = M_ / torch.norm(M_, 'fro')
    normalized_W = W.T / torch.norm(W.T, 'fro')
    graph.W_M_dist.append((torch.norm(normalized_W - normalized_M)**2).item())

    # mutual coherence
    def coherence(V):
        G = V.T @ V
        G += torch.ones((num_classes, num_classes), device=device) / (num_classes - 1)
        G -= torch.diag(torch.diag(G))
        return torch.norm(G, 1).item() / (num_classes * (num_classes - 1))

    graph.cos_M.append(coherence(M_ / M_norms))
    graph.cos_W.append(coherence(W.T / W_norms))

    return mean


def _evaluate_test(model, criterion_summed, device, num_classes, test_loader, cls_num_list_test):
    """Return ``(test_loss, accuracy, per_class_acc)`` measured on `test_loader`.

    Loss and accuracy are normalised by ``len(test_loader.dataset)``. Per-class
    accuracy is normalised by ``cls_num_list_test[c]``; when that argument is
    None the class counts observed in `test_loader` are used instead. A class
    with a zero count gets ``nan`` rather than raising.
    Must be called with autograd disabled.
    """
    correct = 0
    test_loss = 0
    per_class_acc = {c: 0 for c in range(num_classes)}
    count_classes = cls_num_list_test is None
    seen = {c: 0 for c in range(num_classes)}

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)

        output = model(data)
        test_loss += _summed_loss(criterion_summed, output, target, num_classes)

        hits = torch.argmax(output, dim=1) == target
        correct += hits.sum().item()

        for c in range(num_classes):
            is_c = target == c
            per_class_acc[c] += (hits & is_c).sum().item()
            if count_classes:
                seen[c] += is_c.sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    acc = correct / num_samples

    counts = seen if count_classes else cls_num_list_test
    for c in range(num_classes):
        per_class_acc[c] = per_class_acc[c] / counts[c] if counts[c] else float('nan')

    return test_loss, acc, per_class_acc


def analysis(graph, model, criterion_summed, device, num_classes, loader, test_loader, NC_analysis=False, cls_num_list = None, features = None, epoch = None, classifier = None, cls_num_list_test = None):
    """Record neural-collapse and test-set metrics for the current model state.

    Args:
        graph: object whose list attributes (``loss``, ``accuracy``,
            ``test_loss``, ...) receive one new entry per call.
        model: network to evaluate; switched to eval mode here.
        criterion_summed: sum-reduced loss; ``nn.CrossEntropyLoss`` and
            ``nn.MSELoss`` are supported, anything else contributes zero loss.
        device: device the batches are moved to.
        num_classes: number of classes (indices ``0..num_classes-1``).
        loader: batches used for the neural-collapse statistics.
        test_loader: batches used for test loss / accuracy.
        NC_analysis: when False only the test metrics are computed.
        cls_num_list: accepted for compatibility; not used by this function.
        features: object whose ``value`` attribute holds the classifier input
            of the most recent forward pass (set by a forward hook in this
            file). Required when ``NC_analysis`` is True.
        epoch: epoch number, used only in the progress-bar text.
        classifier: final linear layer; its ``weight`` is compared with the
            class means. Required when ``NC_analysis`` is True.
        cls_num_list_test: mapping ``class -> number of test samples``; when
            None the counts are taken from ``test_loader``.

    Returns:
        ``(mu_c_save, per_class_acc)``: the list of per-class mean feature
        tensors (None when ``NC_analysis`` is False) and a dict of per-class
        test accuracies.

    Notes:
        Both stages run under ``torch.no_grad()``; previously the
        neural-collapse passes built an autograd graph for every forward.
    """
    model.eval()

    mu_c_save = None
    if NC_analysis:
        with torch.no_grad():
            mu_c_save = _collapse_analysis(
                graph, model, criterion_summed, device, num_classes,
                loader, features, epoch, classifier)

    # test error
    with torch.no_grad():
        test_loss, acc, per_class_acc = _evaluate_test(
            model, criterion_summed, device, num_classes,
            test_loader, cls_num_list_test)

    graph.test_loss.append(test_loss)
    graph.test_acc.append(acc)

    print(f'Test accuracy: {100 * acc} %')
    print("Testing per_class_acc: " + str(per_class_acc))

    return mu_c_save, per_class_acc

###########################################################################################################################



#------- parameters ---------------------------------------
# Debug switch: when True, each training / NC-analysis pass stops early,
# as soon as its 1-based batch index exceeds 20.
debug = False

# --- data ---
dataset_name = "cifar10"
im_size = padded_im_size = 32   # image side and padded image side
C = 10                          # number of classes

# --- objective ---
loss_name = 'CrossEntropyLoss'

# --- optimisation schedule ---
epochs = 350
batch_size = 128
momentum = 0.9
weight_decay = 5e-4
lr_decay = 0.1
# epochs at which the learning rate is multiplied by lr_decay:
# one third and two thirds of the way through training
epochs_lr_decay = [epochs // 3, epochs * 2 // 3]



def run():
    """Train and analyse one ResNet-18 per imbalance ratio and save the results.

    For each ratio ``R`` in ``[1, 5, 10, 100]`` this function:

    * prepares ``./Saved_Training_Data/<dataset_name>/R_<R>/``; a finished
      experiment (marker ``ExpComplete.txt`` present) is skipped, an
      unfinished one is wiped and restarted;
    * builds a step-imbalanced training set with five majority classes (0-4)
      and five minority classes (5-9) and the standard test set;
    * trains with SGD for ``epochs`` epochs, running ``analysis`` at the epochs
      listed in ``epoch_list`` and re-pickling the ``graphs`` object to
      ``graphs_save.pkl`` each time;
    * finally stores ``mu_c_list_train``, ``W_list``, ``B_list``,
      ``train_accuracies_list`` and ``test_accuracies_list`` with
      ``torch.save`` and writes the completion marker.

    Uses the module-level settings ``dataset_name``, ``C``, ``epochs``,
    ``batch_size``, ``momentum``, ``weight_decay``, ``lr_decay``,
    ``epochs_lr_decay``, ``im_size`` and ``padded_im_size``.

    Raises:
        ValueError: if ``dataset_name`` is neither ``"cifar10"`` nor ``"mnist"``.
    """
    root_path = "./Saved_Training_Data/"
    if dataset_name == "cifar10":
        input_ch = 3
        root_path += "cifar10/"
    elif dataset_name == "mnist":
        input_ch = 1
        root_path += "mnist/"
    else:
        # fail early instead of hitting a NameError on `input_ch` later
        raise ValueError("Unsupported dataset_name: " + str(dataset_name))

    # analysis parameters: epochs at which the model is analysed
    epoch_list = [1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,
                  12,  13,  14,  16,  17,  19,  20,  22,  24,  27,   29,
                  32,  35,  38,  42,  45,  50,  54,  59,  65,  71,   77,
                  85,  92,  101, 110, 121, 132, 144, 158, 172, 188,  206,
                  225, 245, 268, 293, 320, 350]

    # per-class training-set sizes for each supported ratio
    N_maj_dict = {1: 2525, 2: 3366, 5: 4208, 10: 4591, 20: 4809, 50: 4950, 100: 5000}
    N_min_dict = {1: 2525, 2: 1683, 5: 841, 10: 459, 20: 240, 50: 99, 100: 50}
    maj_classes = [0, 1, 2, 3, 4]
    min_classes = [5, 6, 7, 8, 9]
    classes = maj_classes + min_classes

    workers = 4
    NC_analysis = True

    for R in [1, 5, 10, 100]:

        print("_" * 50)
        print("Experiment for Ratio: " + str(R))

        experiment_name = "R_" + str(R)
        save_log_path = root_path + experiment_name + "/"

        experiment_complete_flag_file = os.path.join(save_log_path, "ExpComplete.txt")
        print("save_log_path: " + str(save_log_path))
        if not os.path.exists(save_log_path):
            os.makedirs(save_log_path, exist_ok=True)
        elif not os.path.exists(experiment_complete_flag_file):
            # leftovers of an interrupted run: start this ratio from scratch
            shutil.rmtree(save_log_path)
            os.makedirs(save_log_path, exist_ok=True)
        else:
            print("Skipping this experiments, already done ...")
            continue

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Device: ", device)

        n_c_train_target = {}
        for c in maj_classes:
            n_c_train_target[c] = N_maj_dict[R]
        for c in min_classes:
            n_c_train_target[c] = N_min_dict[R]
        print("n_c_target: " + str(n_c_train_target))
        N_train_total = sum(n_c_train_target.values())
        print("N_train_total: " + str(N_train_total))

        #-------  model ---------------------------------------
        # No pretrained weights are requested (the library default); the
        # deprecated `pretrained=False` keyword is simply omitted.
        model = models.resnet18(num_classes=C)
        # Small dataset filter size used by He et al. (2015)
        model.conv1 = nn.Conv2d(input_ch, model.conv1.weight.shape[0], 3, 1, 1, bias=False)
        model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)
        model.fc = nn.Linear(in_features=model.fc.in_features, out_features=C, bias=False)
        model = model.to(device)

        class features:
            pass

        def hook(module, inputs, output):
            # keep a detached copy so no autograd graph is held between batches
            features.value = inputs[0].detach().clone()

        # register hook that saves last-layer input into features
        classifier = model.fc
        classifier.register_forward_hook(hook)

        #-------  dataset imbalance -------------------------------------
        # NOTE(review): the dataset classes receive imb_factor=R (not 1/R), as
        # in the original code; whether they use it at all when
        # n_c_train_target is given depends on their implementation — confirm.
        if dataset_name == "cifar10":
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])

            transform_val = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])

            train_dataset = IMBALANCECIFAR10(root='./class_imbalance/data', imb_type="step", imb_factor=R,
                                             rand_number=1, train=True, download=True,
                                             transform=transform_train, n_c_train_target=n_c_train_target, classes=classes)
            val_dataset = datasets.CIFAR10(root='./class_imbalance/data', train=False, download=True, transform=transform_val)

            cls_num_list_test = {c: 1000 for c in range(C)}

        else:  # "mnist" (dataset_name was validated above)
            transform = transforms.Compose([transforms.Pad((padded_im_size - im_size)//2),
                                            transforms.ToTensor(),
                                            transforms.Normalize(0.1307, 0.3081)])

            train_dataset = IMBALANCEMNIST(root='./class_imbalance/data', imb_type="step", imb_factor=R,
                                           rand_number=1, train=True, download=True,
                                           transform=transform, n_c_train_target=n_c_train_target, classes=classes)
            val_dataset = datasets.MNIST(root='./class_imbalance/data', train=False, download=True, transform=transform)

            # count the test samples of every class
            cls_num_list_test = {c: 0 for c in range(C)}
            for label in val_dataset.targets:
                cls_num_list_test[label.item()] += 1

            train_dataset.data = torch.tensor(train_dataset.data)
        print("+" * 10)
        print("cls_num_list_test: " + str(cls_num_list_test))
        print("+" * 10)

        cls_num_list = train_dataset.get_cls_num_list()
        print('\nTotal number of samples: ', sum(cls_num_list))
        print('cls num list:')
        print(cls_num_list)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)

        # same (shuffled) training data, iterated separately for the analysis passes
        analysis_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)

        test_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=True)

        #-------  optimizer ---------------------------------------
        criterion = nn.CrossEntropyLoss()
        criterion_summed = nn.CrossEntropyLoss(reduction='sum')

        # Best lr after hyperparameter tuning
        lr = 0.0679 if dataset_name == "mnist" else 1e-1
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=weight_decay)

        # milestones and decay factor come from the module-level settings
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=epochs_lr_decay,
                                                      gamma=lr_decay)

        graph = graphs()

        cur_epochs = []

        mu_c_list_train = []
        W_list = []
        B_list = []
        train_accuracies_list = []
        test_accuracies_list = []

        for epoch in range(1, epochs + 1):
            per_class_acc_train = train(model, criterion, device, C, train_loader, optimizer, epoch, n_c_train_target)
            lr_scheduler.step()

            if epoch not in epoch_list:
                continue

            cur_epochs.append(epoch)
            mu_c_save, per_class_acc_test = analysis(
                graph, model, criterion_summed, device, C, analysis_loader, test_loader,
                NC_analysis=NC_analysis, cls_num_list=cls_num_list, features=features,
                epoch=epoch, classifier=classifier, cls_num_list_test=cls_num_list_test)

            graph.cur_epochs = cur_epochs
            with open(save_log_path + 'graphs_save.pkl', "wb") as f1:
                pickle.dump(graph, f1)

            # Store independent snapshots. `.to("cpu")` alone returns the live
            # parameter when training on CPU, so every entry would alias the
            # final weights.
            W_list.append(classifier.weight.detach().cpu().clone())
            Bias = classifier.bias
            if Bias is None:
                Bias = torch.zeros(C)
            else:
                Bias = Bias.detach().cpu().clone()
            B_list.append(Bias)

            mu_c_list_train.append(mu_c_save)
            train_accuracies_list.append(per_class_acc_train)
            test_accuracies_list.append(per_class_acc_test)

            print(f'Checkpoint saved. Epoch: {epoch} ')

        torch.save(mu_c_list_train, save_log_path + "mu_c_list_train")
        torch.save(W_list, save_log_path + "W_list")
        torch.save(B_list, save_log_path + "B_list")
        torch.save(train_accuracies_list, save_log_path + "train_accuracies_list")
        torch.save(test_accuracies_list, save_log_path + "test_accuracies_list")

        # Mark the experiment as finished with an empty marker file. Markers
        # left as directories by earlier runs still satisfy the
        # os.path.exists check above.
        with open(experiment_complete_flag_file, "w"):
            pass

# Start the experiments only when executed as a script. The data loaders use
# worker processes; with the "spawn" start method each worker re-imports this
# module, and an unguarded call would launch the whole run again.
if __name__ == "__main__":
    run()